from collections import deque, defaultdict

from transitions.core import listify
from transitions.extensions.markup import HierarchicalMarkupMachine


# Source text used as the body of every generated stub method (trigger, ``may_*``,
# ``is_*``, ``to_*`` stubs and ``BaseModel.trigger``) in the module emitted by
# generate_base_model. The stubs are meant to be overridden, so executing one of
# them raises RuntimeError.
_placeholder_body = "raise RuntimeError('This should be overridden')"


def generate_base_model(config):
    """Generate the source of an abstract ``BaseModel`` class for a machine configuration.

    A ``HierarchicalMarkupMachine`` is built from *config*. Its markup is walked to emit
    the following into a single class body:

    * ``<trigger>`` and ``may_<trigger>`` stubs for every trigger.
    * ``is_<state>`` stubs for every state, with an ``allow_substates`` parameter when any
      top-level state has children.
    * ``to_<state>`` / ``may_to_<state>`` stubs for every state, if the machine has auto
      transitions enabled.
    * An ``@abstractmethod`` for every callback referenced by name (a string). Callables
      are skipped.

    Nested state names are prefixed with their parents' names joined by ``_``.

    Triggers and callbacks are emitted in sorted order, so repeated runs on the same
    configuration produce identical output.

    :param config: Keyword arguments passed verbatim to ``HierarchicalMarkupMachine``.
    :return: The generated Python module source as a string.
    """
    m = HierarchicalMarkupMachine(**config)
    markup = m.markup
    model_attribute = markup.get("model_attribute", "state")

    # Machine-wide callbacks.
    callbacks = set()
    for key in ("prepare_event", "before_state_change", "after_state_change",
                "on_exception", "on_final", "finalize_event"):
        callbacks.update(markup[key])

    triggers = set()
    state_parts = []
    has_nested_states = any("children" in state for state in markup["states"])
    is_extra_params = ", allow_substates=False" if has_nested_states else ""

    # Each stack entry is one nesting level: (states, transitions defined there, name prefix).
    stack = [(markup["states"], markup["transitions"], "")]
    while stack:
        states, transitions, prefix = stack.pop()
        # Transitions belong to the level, not to an individual state. Collect them once
        # per level: scanning them once per state repeated identical work and skipped
        # levels without states.
        for tran in transitions:
            triggers.add(tran["trigger"])
            for key in ("prepare", "conditions", "unless", "before", "after"):
                callbacks.update(tran.get(key) or ())
        for state in states:
            state_name = state["name"]
            state_parts.append(
                f"    def is_{prefix}{state_name}(self{is_extra_params})"
                f" -> bool: {_placeholder_body}\n"
            )
            if m.auto_transitions:
                state_parts.append(
                    f"    def to_{prefix}{state_name}(self) -> bool: {_placeholder_body}\n"
                    f"    def may_to_{prefix}{state_name}(self) -> bool: {_placeholder_body}\n"
                )
            state_parts.append("\n")
            if "children" in state:
                stack.append((state["children"], state.get("transitions", []), prefix + state_name + "_"))
    state_block = "".join(state_parts)

    # Sets iterate in hash order, which changes between interpreter runs (string hash
    # randomization). Sort so the generated source is reproducible.
    trigger_block = "".join(
        f"    def {trigger_name}(self) -> bool: {_placeholder_body}\n"
        f"    def may_{trigger_name}(self) -> bool: {_placeholder_body}\n"
        for trigger_name in sorted(triggers)
    )

    extra_params = "event_data: EventData" if m.send_event else "*args: Any, **kwargs: Any"
    # Only callbacks referenced by name can become methods; callables are skipped.
    callback_block = "".join(
        f"    @abstractmethod\n"
        f"    def {callback_name}(self, {extra_params}) -> Optional[bool]: ...\n"
        for callback_name in sorted(cb for cb in callbacks if isinstance(cb, str))
    )

    template = f"""# autogenerated by transitions
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from transitions.core import CallbacksArg, StateIdentifier, EventData


class BaseModel(metaclass=ABCMeta):
    {model_attribute}: "StateIdentifier" = ""
    def trigger(self, name: str) -> bool: {_placeholder_body}

{trigger_block}
{state_block}\
{callback_block}"""

    return template


def with_model_definitions(cls):
    """Class decorator that makes ``cls.add_model`` register transitions declared on models.

    Transition specs recorded in ``TriggerPlaceholder.definitions`` for a model's class
    (via :func:`event` / :func:`add_transitions`) are added to the machine with
    ``add_transition`` before the model itself is passed to the original ``add_model``.

    :param cls: A machine class providing ``add_model`` and ``add_transition``.
    :return: *cls*, with ``add_model`` replaced in place.

    The patched ``add_model`` raises ``ValueError`` if a recorded spec is neither a
    list/tuple of positional ``add_transition`` arguments nor a dict of keyword arguments.
    """
    from functools import wraps

    add_model = getattr(cls, "add_model")

    @wraps(add_model)
    def add_model_override(self, model, initial=None):
        # NOTE(review): flag consumed outside this block; its purpose is not visible here.
        self.model_override = True
        # Use a distinct loop variable. Rebinding ``model`` here previously meant the final
        # ``add_model`` call only received the last model of a list.
        for mod in listify(model):
            mod = self if mod == "self" else mod
            for name, specs in TriggerPlaceholder.definitions.get(mod.__class__, {}).items():
                for spec in specs:
                    if isinstance(spec, (list, tuple)):
                        self.add_transition(name, *spec)
                    elif isinstance(spec, dict):
                        self.add_transition(name, **spec)
                    else:
                        raise ValueError(f"Cannot add {spec} for event {name} to machine")
        # Forward the caller's argument unchanged. The wrapped add_model is expected to accept
        # a list and the "self" literal itself (transitions' Machine.add_model does); confirm
        # for custom machine classes.
        add_model(self, model, initial)

    setattr(cls, "add_model", add_model_override)
    return cls


class TriggerPlaceholder:
    """Class attribute standing in for a trigger until a machine provides the real one.

    When assigned in a class body, the placeholder records its transition specs in the
    shared ``definitions`` registry under the owning class and the attribute name.
    Calling the placeholder itself is an error.
    """

    # Shared by all placeholders: owner class -> trigger name -> list of transition specs.
    definitions = defaultdict(lambda: defaultdict(list))

    def __init__(self, configs):
        # A deque so that add_transitions can cheaply prepend specs from outer decorators.
        self.configs = deque(configs)

    def __set_name__(self, owner, name):
        """Register this placeholder's specs for attribute *name* of class *owner*."""
        if not self.configs:
            # Nothing to register; leave the registry untouched.
            return
        TriggerPlaceholder.definitions[owner][name].extend(self.configs)

    def __call__(self, *args, **kwargs):
        """Fail loudly: the placeholder was never replaced by a real trigger."""
        raise RuntimeError("Trigger was not initialized correctly!")


def event(*configs):
    """Declare a trigger on a model class, holding the given transition specs.

    Each positional argument is one spec: either a dict of ``add_transition`` keyword
    arguments (such as :func:`transition` returns) or a list of positional arguments.
    """
    placeholder = TriggerPlaceholder(configs)
    return placeholder


def add_transitions(*configs):
    """Decorator attaching transition specs to a trigger.

    The decorated target is handled in one of two ways:

    * A ``TriggerPlaceholder`` (for example, the result of another ``add_transitions``)
      gets *configs* prepended, so stacked decorators keep their top-to-bottom order.
    * Any other target (for example, a stub method) is replaced by a new placeholder
      holding *configs*.
    """
    def _decorate(target):
        if not isinstance(target, TriggerPlaceholder):
            return TriggerPlaceholder(configs)
        # extendleft inserts items one by one at the front, so feed them reversed to keep order.
        target.configs.extendleft(reversed(configs))
        return target

    return _decorate


def transition(source, dest=None, conditions=None, unless=None, before=None, after=None, prepare=None):
    """Build a transition spec dict whose keys are ``add_transition`` keyword arguments.

    Every key is always present; unspecified ones are ``None``.
    """
    return dict(
        source=source,
        dest=dest,
        conditions=conditions,
        unless=unless,
        before=before,
        after=after,
        prepare=prepare,
    )
